import sys
sys.path.append('..')
sys.path.append('../../src')
import numpy as np
import pickle
import torch
import torch.nn as nn
from global_var import *
from normalize import *
from data_load import *
from utils import *
from AE import AutoEncoder
from VAE import VAE
from  Whisper import Whisper
import ExtBound as ExtBound
import KITree as KITree
import importlib
import utils as Utils
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
from utils import evaluate_and_save_results_Hyperparameters

# Re-execute the KITree module so edits to it take effect when this script is
# re-run inside the same interpreter session.
importlib.reload(KITree)

# Target (black-box) models to evaluate. Names handled by the loop below:
# "VAE", "AE", "Whisper", "OCSVM", "IForest".
models = ['IForest']

# Dataset name -> list of subsets. Each subset has its own normalizer file
# (<dataset>_<subset>.norm) and target-model file (<model>_<dataset>_<subset>.model).
datasets = {
    'KDD': ['kddcup'],
    # NOTE(review): 'runsomware' spelling presumably matches the on-disk file
    # names; do not correct it without renaming those files.
    'toniot_custom': ['dos', 'mitm', 'runsomware'],
    # 'Monday' is intentionally left out of this run.
    'cicids_custom': ['Friday', 'Thursday', 'Tuesday'],
    'Web': ['web'],
}

# Hyperparameter grids (not referenced in the visible code of this script).
max_levels = [8]
n_beams = [2, 4, 6, 8, 10]
rhos = [0.01, 0.05, 0.1, 0.5, 1]
etas = [4, 6, 8, 10, 12]
split_ratios = [5, 10, 15, 20, 25, 30]



def loadData(dataset, subset):
    """Load one ``(dataset, subset)`` pair and apply its saved normalizer.

    Training features/targets come from ``load_data(..., 'train')``; test
    features/targets come from a separate ``load_data(..., mode='test')`` call.
    Both feature matrices are transformed with the normalizer unpickled from
    ``NORMALIZER_DIR/<dataset>_<subset>.norm``.

    Returns:
        ``(train_data, test_data, train_target, test_target)``
    """
    # NOTE(review): the 'train' call also returns a test split (2nd and 4th
    # values); it is discarded in favour of the dedicated 'test' load below.
    train_data, _, train_target, _ = load_data(dataset, subset, 'train')
    test_data, test_target = load_data(dataset, subset, mode='test')

    norm_path = os.path.join(NORMALIZER_DIR, f'{dataset}_{subset}.norm')
    with open(norm_path, 'rb') as fh:
        normalizer = pickle.load(fh)

    return (
        normalizer.transform(train_data),
        normalizer.transform(test_data),
        train_target,
        test_target,
    )

def perturb_data_point(data_point, delta):
    """Return ``data_point`` shifted by ``delta`` using ``+``.

    For a NumPy array this adds ``delta`` element-wise; the input is not
    modified.
    """
    shifted = data_point + delta
    return shifted

import time

max_level_ = 8

# Uniform shift added to every feature of a test point when checking how
# stable KITree's predictions are under a small input change.
PERTURB_DELTA = 0.001


def _load_blackbox(model, dataset, subset, X):
    """Load the target model for one ``(dataset, subset)`` pair and wrap its scorer.

    Args:
        model: Target model name ("VAE", "AE", "Whisper", "IForest", "OCSVM").
        dataset: Dataset key, used to build the model file name.
        subset: Subset key, used to build the model file name.
        X: Features on which the returned ``score_`` is computed.

    Returns:
        ``(blackbox_model, func_, score_, thres_)``: the loaded model, the
        scoring callable handed to ``KITree``, the scores of ``X`` and the
        threshold handed to ``KITree``.
    """
    model_path = os.path.join(TARGET_MODEL_DIR, f'{model}_{dataset}_{subset}.model')

    if model in ("VAE", "AE"):
        # Torch modules are mapped to cuda:0 and then moved to DEVICE.
        blackbox_model = torch.load(model_path, map_location='cuda:0').cuda(DEVICE)
        blackbox_model.eval()
        func_ = lambda x: blackbox_model.score_samples(x)
        return blackbox_model, func_, func_(X), blackbox_model.thres

    # Trusted local artifact; never point this at untrusted pickle files.
    with open(model_path, 'rb') as f:
        blackbox_model = pickle.load(f)

    if model == "Whisper":
        # NOTE(review): score_ is the raw score_samples output while func_
        # negates it, so the two disagree in sign for Whisper. Kept as-is;
        # confirm which orientation KITree expects.
        score_ = blackbox_model.score_samples(X)
        thres_ = blackbox_model.threshold
        func_ = lambda x: -blackbox_model.score_samples(x)
    else:
        # Assuming scikit-learn detectors (they expose offset_): their
        # decision_function is positive for inliers, so it is negated here.
        # NOTE(review): sklearn defines decision_function = score_samples -
        # offset_, so -offset_ is the cut-off for -score_samples, whereas the
        # cut-off for -decision_function is 0. Kept as-is; confirm against
        # how KITree uses thres_.
        score_ = -blackbox_model.decision_function(X)
        thres_ = -blackbox_model.offset_
        func_ = lambda x: -blackbox_model.decision_function(x)
    return blackbox_model, func_, score_, thres_


def _to_anomaly_labels(model, raw_predictions):
    """Convert a target model's ``predict`` output to 0/1 labels.

    For IForest, OCSVM and Whisper, 1 is mapped to 0 and -1 to 1
    (scikit-learn's inlier/outlier convention). Other values, and other
    models' predictions, pass through unchanged.
    """
    if model not in ("IForest", "OCSVM", "Whisper"):
        return raw_predictions
    labels = np.where(raw_predictions == 1, 0, raw_predictions)
    return np.where(labels == -1, 1, labels)


def _predict_perturbed(kdt, X_eval, delta):
    """Predict every row of ``X_eval`` after adding ``delta`` to all its features.

    Each row is predicted as its own 1-row batch, so the result does not rely
    on ``kdt.predict`` treating rows of a larger batch independently.
    Returns a flat array of predictions in row order.
    """
    per_row = [
        np.ravel(kdt.predict((np.asarray(row) + delta).reshape(1, -1)))
        for row in X_eval
    ]
    # Concatenate once instead of np.append-ing per row (quadratic copying).
    return np.concatenate(per_row) if per_row else np.empty((0,))


def _concat_as_int(parts):
    """Concatenate 1-D pieces and cast the result to int.

    Concatenation starts from an empty float64 array, so values are promoted
    to float before the int cast and an empty ``parts`` yields an empty array.
    """
    return np.concatenate([np.empty((0,)), *parts]).astype(int)


for model in models:
    for dataset in datasets:
        # Announce the run before doing the work so any per-subset output or
        # error is attributed to the right dataset/model.
        print("================ Processing ( {dataset} ) dataset, using model = ( {model} )".format(dataset=dataset,
                                                                                                    model=model),
              "================")

        prediction_parts, original_parts, perturbed_parts, label_parts = [], [], [], []
        # Timings are summed over all subsets and divided by the total sample
        # count, so the reported per-sample cost covers the whole dataset
        # rather than only its last subset.
        total_train_time = total_pred_time = 0.0
        n_train = n_test = 0

        for subset in datasets[dataset]:
            X_train, X_test, y_train, y_test = loadData(dataset, subset)
            blackbox_model, func_, score_, thres_ = _load_blackbox(model, dataset, subset, X_train)

            # Fit KITree on the training data and the black-box scores.
            kdt_ = KITree.KITree(func_, thres_)
            start = time.perf_counter()
            kdt_.fit(X_train, y_train, score_)
            total_train_time += time.perf_counter() - start
            n_train += X_train.shape[0]

            start = time.perf_counter()
            predictions = kdt_.predict(X_test)
            total_pred_time += time.perf_counter() - start
            n_test += X_test.shape[0]

            prediction_parts.append(np.ravel(predictions))
            original_parts.append(np.ravel(_to_anomaly_labels(model, blackbox_model.predict(X_test))))
            perturbed_parts.append(_predict_perturbed(kdt_, X_test, PERTURB_DELTA))
            label_parts.append(np.ravel(y_test))

        # NaN (rather than a stale or undefined value) when a dataset has no samples.
        avg_train_time = total_train_time / n_train if n_train else float('nan')
        avg_pred_time = total_pred_time / n_test if n_test else float('nan')

        evaluate_and_save_results_Hyperparameters(
            _concat_as_int(label_parts),
            _concat_as_int(prediction_parts),
            _concat_as_int(original_parts),
            _concat_as_int(perturbed_parts),
            dataset,
            baseline="ours_exam",
            black_model=model,
            avg_train_time=avg_train_time,
            avg_pred_time=avg_pred_time,
        )

print("============================= max_level All works are finished! =============================")